package fec

import (
	"github.com/klauspost/reedsolomon"
	"github.com/lucas-clemente/quic-go/internal/utils"
	"bytes"
	"errors"
	"log"
	"github.com/lucas-clemente/quic-go/internal/protocol"
)

// Compile-time assertion that *ReedSolomonFECScheme implements BlockFECScheme.
var _ BlockFECScheme = (*ReedSolomonFECScheme)(nil)

var (
	// ReedSolomonNoRepairSymbolInFECGroup is returned by RecoverPackets when the
	// FEC group carries no repair symbol to decode from.
	ReedSolomonNoRepairSymbolInFECGroup = errors.New("ReedSolomon FEC Scheme: impossible to recover FEC Group with no repair symbol")
	// ReedSolomonInvalidNumberOfSymbols is returned by NewReedSolomonFECScheme
	// when asked to build a scheme with zero repair symbols.
	ReedSolomonInvalidNumberOfSymbols = errors.New("ReedSolomon FEC Scheme: impossible to build FEC Scheme with a number of symbols equal to zero")
)

// ReedSolomonFECScheme is a block FEC scheme backed by the
// github.com/klauspost/reedsolomon erasure coder. It has no fixed number of
// source symbols: that count is taken from the FEC group handed to
// GetRepairSymbols and RecoverPackets.
type ReedSolomonFECScheme struct {
	// numberOfRepairSymbols is the repair-symbol count given to the constructor;
	// it is never zero when built through NewReedSolomonFECScheme.
	numberOfRepairSymbols byte
}

// NewReedSolomonFECScheme builds a scheme configured for k repair symbols.
// k must be greater than zero; otherwise ReedSolomonInvalidNumberOfSymbols is
// returned together with a nil scheme.
//
// NOTE(review): GetRepairSymbols receives the number of repair symbols as an
// argument and does not read the k stored here — confirm which one callers
// expect to be authoritative.
func NewReedSolomonFECScheme(k byte) (*ReedSolomonFECScheme, error) {
	if k != 0 {
		return &ReedSolomonFECScheme{numberOfRepairSymbols: k}, nil
	}
	return nil, ReedSolomonInvalidNumberOfSymbols
}


// GetRepairSymbols computes numberOfSymbols Reed-Solomon repair symbols over
// the packets returned by fecGroup.GetPackets().
//
// The encoder requires equally-sized shards, so every packet is zero-padded to
// the length of the longest one. The padded slices are written back into the
// slice returned by GetPackets. Each returned RepairSymbol has that same length
// and carries its 0-based index in SymbolNumber. The FECPayloadID argument is
// not used by this scheme.
//
// An error is returned when the encoder cannot be built for these sizes, or
// when encoding fails (e.g. every packet is empty, which yields zero-sized
// shards).
func (f *ReedSolomonFECScheme) GetRepairSymbols(fecGroup FECContainer, numberOfSymbols uint, _ protocol.FECPayloadID) ([]*RepairSymbol, error) {
	packets := fecGroup.GetPackets()
	numberOfPackets := len(packets)

	maxLen := 0
	for _, packet := range packets {
		maxLen = utils.Max(maxLen, len(packet))
	}

	// Shards [0, numberOfPackets) are the padded source packets; the remaining
	// shards are zeroed buffers that Encode fills with parity data.
	reedSolomonInput := make([][]byte, numberOfPackets+int(numberOfSymbols))
	for i, packet := range packets {
		delta := maxLen - len(packet)
		// The full slice expression caps the capacity at the length, so append
		// allocates a new array instead of writing the padding into packet's
		// spare capacity, which may be shared with other data.
		packets[i] = append(packet[:len(packet):len(packet)], bytes.Repeat([]byte{0}, delta)...)
		reedSolomonInput[i] = packets[i]
	}
	for i := numberOfPackets; i < len(reedSolomonInput); i++ {
		reedSolomonInput[i] = make([]byte, maxLen)
	}

	enc, err := reedsolomon.New(numberOfPackets, int(numberOfSymbols))
	if err != nil {
		return nil, err
	}
	// Encode can fail even with equally-sized shards (e.g. all shards empty);
	// report it instead of handing out all-zero repair symbols.
	if err = enc.Encode(reedSolomonInput); err != nil {
		return nil, err
	}

	symbols := make([]*RepairSymbol, 0, numberOfSymbols)
	for i, symbol := range reedSolomonInput[numberOfPackets:] {
		symbols = append(symbols, &RepairSymbol{
			SymbolNumber:  byte(i),
			Data:          symbol,
			Convolutional: f.Convolutional(),
		})
	}
	return symbols, nil
}
// RecoverPackets reconstructs the missing source packets of fecGroup from the
// packets and repair symbols it holds.
//
// fecGroup.packets is indexed by the packet's position in the group. A nil
// entry is missing, and so is every position from len(fecGroup.packets) up to
// TotalNumberOfPackets. Received packets are zero-padded in place (stored back
// into fecGroup.packets) to the length of the first repair symbol, because the
// decoder needs equally-sized shards. Recovered packets are returned in
// increasing position order and keep that zero padding.
//
// Errors:
//   - ReedSolomonNoRepairSymbolInFECGroup when the group has no repair symbol.
//   - An error instead of a panic when the group is inconsistent: more packets
//     than TotalNumberOfPackets, a repair symbol number >=
//     TotalNumberOfRepairSymbols, or a packet longer than the repair symbols.
//   - Any error from the reedsolomon encoder/decoder (e.g. too few shards).
func (f *ReedSolomonFECScheme) RecoverPackets(fecGroup *FECGroup) ([][]byte, error) {
	// RepairSymbols[0] is read below to size the shards, so an empty slice must
	// be rejected as well as an announced total of zero.
	if fecGroup.TotalNumberOfRepairSymbols == 0 || len(fecGroup.RepairSymbols) == 0 {
		return nil, ReedSolomonNoRepairSymbolInFECGroup
	}
	n, k := fecGroup.TotalNumberOfPackets, fecGroup.TotalNumberOfRepairSymbols
	if len(fecGroup.packets) > n {
		return nil, errors.New("ReedSolomon FEC Scheme: FEC Group contains more packets than its total number of packets")
	}
	enc, err := reedsolomon.New(n, k)
	if err != nil {
		return nil, err
	}

	// Shards [0, n) are source packets and [n, n+k) are repair symbols; a nil
	// shard is one the decoder has to reconstruct.
	reedSolomonInput := make([][]byte, n+k)

	for _, symbol := range fecGroup.RepairSymbols {
		// SymbolNumber is presumably parsed from a received frame; reject values
		// that would index past the repair shards instead of panicking.
		if int(symbol.SymbolNumber) >= k {
			return nil, errors.New("ReedSolomon FEC Scheme: repair symbol number out of range")
		}
		reedSolomonInput[n+int(symbol.SymbolNumber)] = symbol.Data
	}

	symbolLen := len(fecGroup.RepairSymbols[0].Data)
	var indicesToRecover []int
	for i, packet := range fecGroup.packets {
		if packet == nil {
			indicesToRecover = append(indicesToRecover, i)
			continue
		}
		delta := symbolLen - len(packet)
		if delta < 0 {
			// bytes.Repeat would panic on a negative count.
			return nil, errors.New("ReedSolomon FEC Scheme: packet is longer than the repair symbols")
		}
		// The full slice expression forces append to allocate, so the padding is
		// never written into spare capacity shared with another buffer.
		fecGroup.packets[i] = append(packet[:len(packet):len(packet)], bytes.Repeat([]byte{0}, delta)...)
		reedSolomonInput[i] = fecGroup.packets[i]
	}
	// The last packets of the group are missing when fecGroup.packets is shorter
	// than the announced total.
	// TODO: this should not be done by the fec scheme. The missing packets at the end of the fec group should be indicated as being set to nil in fecGroup.packets
	// TODO: easy to do this if we force to use a Setter to change fecGroup.TotalNumberOfPackets
	for i := len(fecGroup.packets); i < n; i++ {
		indicesToRecover = append(indicesToRecover, i)
	}

	if err = enc.ReconstructData(reedSolomonInput); err != nil {
		return nil, err
	}

	recoveredPackets := make([][]byte, 0, len(indicesToRecover))
	for _, i := range indicesToRecover {
		recoveredPackets = append(recoveredPackets, reedSolomonInput[i])
	}
	return recoveredPackets, nil
}
// CanRecoverPackets reports whether RecoverPackets is worth attempting for
// fecGroup. All of the following must hold:
//   - at least one repair symbol has been received;
//   - TotalNumberOfPackets is non-zero;
//   - some source packets are still missing;
//   - received packets plus received repair symbols reach TotalNumberOfPackets.
//
// When the answer is true, the reception counters are logged.
func (*ReedSolomonFECScheme) CanRecoverPackets(fecGroup *FECGroup) bool {
	repairCount := len(fecGroup.RepairSymbols)
	expected := fecGroup.TotalNumberOfPackets
	if repairCount == 0 || expected == 0 {
		return false
	}

	have := fecGroup.CurrentNumberOfPackets()
	switch {
	case have >= expected:
		// nothing is missing, so there is nothing to recover
		return false
	case have+repairCount < expected:
		// not enough symbols in total: recovery is impossible
		return false
	}

	log.Printf("received %d packets (total %d), %d repair symbols (total %d)", len(fecGroup.packetIndexes), fecGroup.TotalNumberOfPackets, len(fecGroup.RepairSymbols), fecGroup.TotalNumberOfRepairSymbols)
	return true
}

// Convolutional reports whether this scheme is convolutional. Reed-Solomon is
// used here as a block code (see the BlockFECScheme assertion), so it is
// always false.
func (*ReedSolomonFECScheme) Convolutional() bool {
	const isConvolutional = false
	return isConvolutional
}